package idgenerator

import (
	"database/sql"
	"fmt"
	"gitee.com/glsoft/go-id-generator/util"
	_ "github.com/go-sql-driver/mysql"
	"sync"
)

// MySQLIdGenerator hands out int64 IDs for a single key, reserving them from
// a MySQL table in batches so that most calls to Next are served from memory.
//
// The IDs still available in memory are currentId+1 through batchMax; once
// that range is exhausted, Next reserves a new batch inside a transaction.
type MySQLIdGenerator struct {
	db        *sql.DB    // database handle used to reserve batches
	key       string     // id generator key name: the row in KeyRecordTableName to draw IDs from
	currentId int64      // last ID handed out (or the base of the most recently reserved batch)
	batchMax  int64      // highest ID covered by the most recently reserved batch
	batch     int64      // number of IDs reserved from MySQL per round-trip
	lock      sync.Mutex // guards currentId and batchMax; held by Next, Init and the accessors
}

// newMySQLIdGenerator builds a generator for the given section (key) backed
// by db. The section must be non-empty. No database access happens here:
// currentId and batchMax both start at zero, so the first call to Next (or an
// explicit Init) reserves the first batch of BatchCount IDs.
func newMySQLIdGenerator(db *sql.DB, section string) (*MySQLIdGenerator, error) {
	if section == "" {
		return nil, fmt.Errorf("section is nil")
	}

	gen := &MySQLIdGenerator{
		db:    db,
		batch: BatchCount,
	}
	if err := gen.SetSection(section); err != nil {
		return nil, err
	}

	// currentId == batchMax == 0 marks the in-memory batch as exhausted.
	return gen, nil
}

// SetSection sets the key (the row in the key record table) this generator
// reserves IDs for. It always returns nil.
//
// The key is written while holding the generator's lock, the same as the
// other setters, because getNextBatch reads it while Next or Init hold that
// lock. Writing it unlocked would be a data race with a concurrent Next.
//
// NOTE(review): changing the key does not reset currentId/batchMax, so IDs
// already reserved under the previous key keep being handed out until the
// current batch is exhausted.
func (m *MySQLIdGenerator) SetSection(key string) error {
	m.lock.Lock()
	defer m.lock.Unlock()
	m.key = key
	return nil
}

// GetCurrentId reports the last ID handed out by Next. It may instead be the
// value last stored with SetCurrentId, or the base of a freshly reserved batch.
// The read happens under the generator's lock.
func (m *MySQLIdGenerator) GetCurrentId() int64 {
	m.lock.Lock()
	current := m.currentId
	m.lock.Unlock()
	return current
}

// SetCurrentId overwrites the last-issued ID under the generator's lock.
// The next call to Next returns id+1 if that is still within batchMax.
// Otherwise Next reserves a new batch first.
func (m *MySQLIdGenerator) SetCurrentId(id int64) {
	m.lock.Lock()
	m.currentId = id
	m.lock.Unlock()
}

// GetBatchMax reports, under the generator's lock, the highest ID that can
// be issued before Next has to reserve another batch.
func (m *MySQLIdGenerator) GetBatchMax() int64 {
	m.lock.Lock()
	upper := m.batchMax
	m.lock.Unlock()
	return upper
}

// SetBatchMax overwrites, under the generator's lock, the upper bound of the
// in-memory batch. Next reserves a new batch once currentId reaches it.
func (m *MySQLIdGenerator) SetBatchMax(max int64) {
	m.lock.Lock()
	m.batchMax = max
	m.lock.Unlock()
}

// getNextBatch reserves the next block of IDs for m.key in one transaction.
// It reads the key's row with SelectForUpdate and advances it with
// UpdateKeySQLFormat by m.batch.
//
// Only after the transaction commits does it move the in-memory range, so
// that Next then hands out id+1 through id+m.batch. If any step fails, the
// transaction is rolled back, the in-memory state is left untouched and the
// error is returned. This includes a missing key row and a failed Commit.
//
// The caller must hold m.lock (Next and Init both do).
//
// NOTE(review): m.key is interpolated into the SQL text with fmt.Sprintf, so
// it must come from trusted configuration. Consider placeholders if the
// SelectForUpdate / UpdateKeySQLFormat formats can be changed.
func (m *MySQLIdGenerator) getNextBatch() error {
	tx, err := m.db.Begin()
	if err != nil {
		return err
	}
	// Rollback after a successful Commit is a no-op (it returns sql.ErrTxDone),
	// so deferring it unconditionally releases the transaction on every
	// early-return path.
	defer func() { _ = tx.Rollback() }()

	selectForUpdate := fmt.Sprintf(SelectForUpdate, KeyRecordTableName, m.key)
	rows, err := tx.Query(selectForUpdate)
	if err != nil {
		return err
	}
	// Deferred calls run last-in first-out, so rows are closed before the rollback.
	defer util.CloseRows(rows)

	var id int64
	haveValue := false
	for rows.Next() {
		if err := rows.Scan(&id); err != nil {
			return err
		}
		haveValue = true
	}
	// An iteration error would otherwise masquerade as "no such key".
	if err := rows.Err(); err != nil {
		return err
	}
	if !haveValue {
		return fmt.Errorf("%s:have no id key", m.key)
	}

	updateIdSql := fmt.Sprintf(UpdateKeySQLFormat, KeyRecordTableName, m.batch, m.key)
	if _, err := tx.Exec(updateIdSql); err != nil {
		return err
	}
	// If the reservation is not durable, handing out IDs from it could
	// duplicate IDs issued by another generator, so a failed Commit is fatal.
	if err := tx.Commit(); err != nil {
		return err
	}

	m.batchMax = id + m.batch
	m.currentId = id
	return nil
}

// Next returns the next ID for this generator's key. When the in-memory
// batch is used up, it first reserves a new batch from MySQL. If that
// reservation fails, Next returns 0 and the error, and no ID is consumed.
// Safe for concurrent use: the whole operation runs under the generator's lock.
func (m *MySQLIdGenerator) Next() (int64, error) {
	m.lock.Lock()
	defer m.lock.Unlock()

	// The batch is exhausted once the following ID would exceed batchMax.
	if m.currentId+1 > m.batchMax {
		if err := m.getNextBatch(); err != nil {
			return 0, err
		}
	}

	m.currentId++
	return m.currentId, nil
}

// Init eagerly reserves a batch of IDs from MySQL under the generator's
// lock. It returns any error from the reservation.
func (m *MySQLIdGenerator) Init() error {
	m.lock.Lock()
	// Deferred so the lock is released even if the reservation panics.
	defer m.lock.Unlock()

	if err := m.getNextBatch(); err != nil {
		return err
	}
	return nil
}
